CREATE OR REPLACE FUNCTION "pipeline"."lle"("source_table" text, "out_table" varchar, "id_col" varchar, "feature_cols" varchar, "n_components" int4, "n_neighbors" int4)
  RETURNS "pg_catalog"."text" AS $BODY$
  import numpy as np
  from sklearn.manifold import LocallyLinearEmbedding
  import json
  import time

  # Initialise the return payload. "status" 0 with "error_msg" "success" is
  # the default outcome; throwError() overwrites both fields on failure.
  # "input_params" echoes the function arguments; "feature_cols" starts empty
  # and is filled in by begin_alg() once the column list has been split.
  # "output_params" receives one entry describing the output table on success.
  result = {
    "status": 0,
    "error_msg": "success",
    "result": {
      "input_params": {
        "out_table": out_table,
        "source_table": source_table,
        "id_col": id_col,
        "feature_cols": [],
        "n_components": n_components,
        "n_neighbors": n_neighbors
      },
      "output_params": []
    }
  }

  # Read data: build the SELECT statement for the source rows
  def get_data_sql(source_table, id_col, feature_cols):
    """Build the SELECT statement that reads the id and feature columns.

    source_table -- relation to read from
    id_col       -- name of the row identifier column
    feature_cols -- comma-separated column list, inserted verbatim

    Returns the SQL text; nothing is executed here. The pieces are
    interpolated as-is, without quoting or escaping.
    """
    pieces = {
      "id": id_col,
      "features": feature_cols,
      "table": source_table,
    }
    return "select %(id)s, %(features)s from %(table)s" % pieces
  
  # Fetch the table's column metadata (column name -> data type)
  def getTableMetaInfo(table_name):
    """Return a {column_name: data_type} mapping for a table.

    table_name -- relation name, either bare ("points") or qualified
                  ("schema.points"); surrounding double quotes on each
                  part are ignored.

    The lookup goes through information_schema.columns. When a schema is
    given it is used as an additional filter, so same-named tables in other
    schemas do not leak columns into the mapping. An unknown table yields
    an empty dict.
    """
    # A bare name splits into a single part; always take the last part as the
    # relation so unqualified names no longer raise IndexError.
    parts = [part.strip().strip('"') for part in table_name.split(".")]
    relation = parts[-1]

    sql_str = ("select column_name, data_type from information_schema.columns"
               " where table_name = $1")
    arg_types = ["text"]
    args = [relation]
    if len(parts) > 1:
      sql_str += " and table_schema = $2"
      arg_types.append("text")
      args.append(parts[-2])

    # Bind the names as parameters instead of splicing them into the SQL text.
    plan = plpy.prepare(sql_str, arg_types)
    rs = plpy.execute(plan, args)

    meta = {}
    for line in rs:
      meta[line['column_name']] = line['data_type']
    return meta
  
  # Create the output table and store the embedding in it
  def saveOutTable(datas, out_table, id_col, n_components, meta):
    """Create the output table and insert the embedded rows into it.

    datas        -- sequence of rows, each [id_value, v1, ..., vN]
    out_table    -- name of the table to create
    id_col       -- name of the identifier column (must be a key of meta)
    n_components -- number of embedding columns; they are named f1..fN
    meta         -- {column_name: data_type} mapping of the source table

    Returns {"out_table_name": <table>, "cols": [id_col, "f1", ...]}.
    Raises ValueError when id_col is missing from meta; errors raised by
    plpy.execute propagate to the caller.
    """
    id_type = meta.get(id_col)
    if id_type is None:
      # Fail with a readable message instead of a str + None TypeError.
      raise ValueError("id column '%s' not found in source table metadata" % id_col)

    # Only these types are written as bare numbers. Every other type is sent
    # as a quoted literal, which PostgreSQL coerces to the column type on
    # INSERT; this also covers the type names information_schema really
    # reports, such as "timestamp without time zone" or "uuid".
    numeric_types = ("smallint", "integer", "bigint", "numeric", "decimal",
                     "real", "double precision")

    def id_literal(value):
      # Render one identifier value as a SQL literal.
      if value is None:
        return "NULL"
      if id_type in numeric_types:
        return repr(value) if isinstance(value, float) else str(value)
      return "'" + str(value).replace("'", "''") + "'"

    def float_literal(value):
      # repr() keeps full double precision (str() truncates on Python 2);
      # non-finite values need PostgreSQL's quoted spellings.
      value = float(value)
      if value != value:
        return "'NaN'"
      if value in (float("inf"), float("-inf")):
        return "'Infinity'" if value > 0 else "'-Infinity'"
      return repr(value)

    table_name = out_table
    component_cols = ["f" + str(i + 1) for i in range(n_components)]

    col_defs = ["%s %s" % (id_col, id_type)]
    col_defs.extend(name + " float" for name in component_cols)
    plpy.execute("CREATE TABLE %s (%s);" % (table_name, ", ".join(col_defs)))

    cols = [id_col] + component_cols

    # An INSERT with an empty VALUES list is a syntax error, so skip it.
    if len(datas) > 0:
      _values = []
      for item in datas:
        fields = [id_literal(item[0])]
        fields.extend(float_literal(v) for v in item[1:])
        _values.append("(" + ", ".join(fields) + ")")
      sql_str = "INSERT INTO %s (%s) VALUES " % (table_name, ", ".join(cols))
      plpy.execute(sql_str + ",".join(_values))

    return {"out_table_name": table_name, "cols": cols}

  def throwError(e):
    """Mark the shared result payload as failed.

    e -- the exception (or any object) whose string form becomes the
         reported error message.

    Sets "status" to 500 and "error_msg" to str(e) on the module-level
    `result` dict; nothing is raised or returned.
    """
    global result
    payload = result
    payload["status"] = 500
    message = str(e)
    payload["error_msg"] = message
  
  # Main routine
  def begin_alg(source_table, out_table, id_col, feature_cols, n_components, n_neighbors):
    """Run Locally Linear Embedding on a table and store the embedding.

    source_table -- relation to read, or a "create view <name> ..." statement
                    that is executed first; the created view is then read
    out_table    -- name of the table that receives the embedding
    id_col       -- identifier column copied to the output table
    feature_cols -- comma-separated feature column names (optionally
                    double-quoted)
    n_components -- number of output dimensions
    n_neighbors  -- neighbourhood size handed to LocallyLinearEmbedding

    Always returns a JSON string of the shared `result` payload. On any
    failure the payload carries status 500 and the error message instead of
    an output table entry.
    """
    global result
    sourceTable = source_table

    # Detect a CREATE VIEW statement regardless of case or spacing.
    tokens = source_table.split()
    is_view_stmt = (len(tokens) >= 2
                    and tokens[0].lower() == "create"
                    and tokens[1].lower() == "view")
    if is_view_stmt:
      try:
        plpy.execute(source_table)
        sourceTable = tokens[2]
      except Exception as e:
        # A failed CREATE VIEW created nothing, so there is nothing to clean
        # up. The previous handler tried to DROP the statement text itself,
        # which raised a second error from inside this handler.
        throwError(e)
        return json.dumps(result)

    sql_str = get_data_sql(sourceTable, id_col, feature_cols)
    features = feature_cols.split(",")
    result["result"]["input_params"]["feature_cols"] = features

    # Result rows are keyed by the bare column name: drop the identifier
    # quotes and any whitespace left around the commas.
    feature_keys = [feature.strip().replace('"', "") for feature in features]

    ids = []
    datas = []
    try:
      rv = plpy.execute(sql_str)
      for item in rv:
        ids.append(item[id_col])
        datas.append([float(item[key]) for key in feature_keys])
      datas = np.array(datas, dtype=np.float64)
    except Exception as e:
      # Covers query errors as well as NULL or non-numeric feature values.
      throwError(e)
      return json.dumps(result)

    try:
      lle = LocallyLinearEmbedding(n_neighbors = n_neighbors, n_components = n_components, method = 'standard')
      res = lle.fit_transform(datas).tolist()
    except Exception as e:
      throwError(e)
      return json.dumps(result)

    # Prefix every embedded row with its identifier.
    out_datas = [[ids[i]] + list(res[i]) for i in range(len(ids))]

    try:
      meta = getTableMetaInfo(sourceTable)
      saveRet = saveOutTable(out_datas, out_table, id_col, n_components, meta)
    except Exception as e:
      throwError(e)
      return json.dumps(result)

    result["result"]["output_params"].append({
      "out_table_name": saveRet["out_table_name"],
      "output_cols": saveRet["cols"]
    })
    return json.dumps(result)
  
  # Entry point: run the algorithm with the function's own arguments and
  # return whatever begin_alg() produced.
  # NOTE(review): the call only happens when __name__ equals '__main__' in the
  # hosting interpreter; otherwise the body falls through without returning a
  # value -- confirm this holds for the target runtime.
  if __name__ == '__main__':
    # NOTE(review): `result` is assigned earlier in this same scope; this
    # declaration presumably makes that assignment land in the globals that
    # the helpers' own `global result` statements read -- confirm before
    # removing or moving it.
    global result
    ret = begin_alg(source_table, out_table, id_col, feature_cols, n_components, n_neighbors)
    return ret
$BODY$
  LANGUAGE plpythonu VOLATILE
  COST 100